# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from hysop.tools.htypes import to_tuple, check_instance, first_not_None
from hysop.tools.numpywrappers import npw
from hysop.tools.io_utils import IO
from hysop.core.graph.graph import op_apply
from hysop.core.graph.computational_graph import ComputationalGraphOperator
from hysop.parameters.scalar_parameter import ScalarParameter
from hysop.parameters.tensor_parameter import TensorParameter
from hysop.backend.host.host_operator import HostOperatorBase
class PlottingOperator(HostOperatorBase):
    """
    Base operator for plotting.

    Subclasses implement :meth:`update` to refresh the figure content from the
    current simulation state. When ``update_ioparams.should_dump`` allows it
    (a clone of ``io_params`` with ``frequency=update_frequency``),
    :meth:`update` is called and, on the visualization rank only, the figure
    is redrawn. When ``save_ioparams.should_dump`` allows it (same with
    ``frequency=save_frequency``), :meth:`save` writes the figure to
    ``<dump_dir>/<name>.png``.
    """

    @classmethod
    def supports_mpi(cls):
        return True

    def __new__(
        cls,
        name=None,
        dump_dir=None,
        update_frequency=1,
        save_frequency=100,
        axes_shape=(1,),
        figsize=(30, 18),
        visu_rank=0,
        fig=None,
        axes=None,
        **kwds,
    ):
        # Plotting-specific arguments are consumed here; only the remaining
        # keyword arguments are forwarded to the base operator.
        return super().__new__(cls, **kwds)

    def __init__(
        self,
        name=None,
        dump_dir=None,
        update_frequency=1,
        save_frequency=100,
        axes_shape=(1,),
        figsize=(30, 18),
        visu_rank=0,
        fig=None,
        axes=None,
        **kwds,
    ):
        """
        Parameters
        ----------
        name: str
            Operator name, also used as the basename of the saved image.
        dump_dir: str, optional
            Output directory of the saved image, defaults to IO.default_path().
        update_frequency: int
            Frequency given to the IO parameters that trigger update/draw.
        save_frequency: int
            Frequency given to the IO parameters that trigger save.
        axes_shape: tuple of int or None
            Shape of the subplot grid. When None and no figure is provided,
            a single subplot (1,) is created. When None and a figure is
            provided, the user axes are kept with their own shape.
        figsize: tuple
            Figure size forwarded to matplotlib.pyplot.subplots.
        visu_rank: int
            MPI rank that performs interactive drawing.
        fig, axes: optional
            User provided figure and axes. Both must be given together.
        kwds:
            Forwarded to the base operator.

        Raises
        ------
        RuntimeError
            If only one of fig and axes is specified.
        """
        import matplotlib.pyplot as plt

        check_instance(name, str)
        check_instance(update_frequency, int, minval=0)
        check_instance(save_frequency, int, minval=0)
        check_instance(axes_shape, tuple, minsize=1, allow_none=True)
        super().__init__(**kwds)

        if (fig is None) ^ (axes is None):
            msg = "figure and axes should be specified at the same time."
            raise RuntimeError(msg)

        dump_dir = first_not_None(dump_dir, IO.default_path())
        imgpath = f"{dump_dir}/{name}.png"

        if fig is None:
            # axes_shape=None is accepted above; plt.subplots(*None) would
            # fail, so fall back to a single subplot.
            if axes_shape is None:
                axes_shape = (1,)
            fig, axes = plt.subplots(*axes_shape, figsize=figsize)
            # Interactive callbacks are only installed on figures we own.
            fig.canvas.mpl_connect("key_press_event", self.on_key_press)
            fig.canvas.mpl_connect("close_event", self.on_close)

        axes = npw.asarray(axes)
        if axes_shape is not None:
            # reshape(None) is a deprecated alias for reshape(()) in NumPy and
            # breaks as soon as more than one user axes is given.
            axes = axes.reshape(axes_shape)

        self.fig = fig
        self.axes = axes
        self.update_frequency = update_frequency
        self.save_frequency = save_frequency
        self.imgpath = imgpath
        self.should_draw = visu_rank == self.mpi_params.rank
        self.running = True
        self.plt = plt
        self.update_ioparams = self.io_params.clone(
            frequency=self.update_frequency, with_last=True
        )
        self.save_ioparams = self.io_params.clone(
            frequency=self.save_frequency, with_last=True
        )

    def draw(self):
        """Redraw and show the figure unless the window was closed."""
        if not self.running:
            return
        self.fig.canvas.draw()
        self.fig.show()
        # Give the GUI event loop a chance to process events.
        self.plt.pause(0.001)

    @op_apply
    def apply(self, **kwds):
        """Run the update/draw step, then the save step (expects simulation=...)."""
        self._update(**kwds)
        self._save(**kwds)

    def _update(self, simulation, **kwds):
        if self.update_ioparams.should_dump(simulation=simulation):
            # update runs on every rank; only the visualization rank draws.
            self.update(simulation=simulation, **kwds)
            if self.should_draw:
                self.draw()

    def _save(self, simulation, **kwds):
        if self.save_ioparams.should_dump(simulation=simulation):
            self.save(simulation=simulation, **kwds)

    @abstractmethod
    def update(self, **kwds):
        """Refresh the figure content; implemented by subclasses."""
        pass

    def save(self, **kwds):
        """Write the figure to self.imgpath."""
        # NOTE(review): this is reached on every MPI rank with the same
        # imgpath; consider restricting the write to a single rank.
        self.fig.savefig(self.imgpath, dpi=self.fig.dpi, bbox_inches="tight")

    def on_close(self, event):
        """Matplotlib close_event callback: stop drawing."""
        self.running = False

    def on_key_press(self, event):
        """Matplotlib key_press_event callback: 'q' closes the figure."""
        key = event.key
        if key == "q":
            self.plt.close(self.fig)
            self.running = False
class ParameterPlotter(PlottingOperator):
    """
    Base operator to plot parameters during runtime.

    Each plotted scalar (or scalar view of a tensor parameter) gets one line
    whose x data is the recorded simulation time and whose y data is the
    recorded history of the parameter value, appended at each :meth:`update`.
    """

    def __init__(
        self, name, parameters, alloc_size=128, fig=None, axes=None, shape=None, **kwds
    ):
        """
        Parameters
        ----------
        name: str
            Operator name.
        parameters:
            Custom axes mode (both fig and axes given): a dict
            {matplotlib Axes: {label: ScalarParameter}}.
            Otherwise one of:
              - a TensorParameter (plotted in a single subplot),
              - a list/tuple of TensorParameter (subplot i holds parameter i),
              - a dict {position: params} where position is an int or a tuple
                of one or two ints and params is a TensorParameter, a
                list/tuple of them, or a dict {label: TensorParameter}.
            Non-scalar tensors are expanded into one line per component.
        alloc_size: int
            Initial capacity of the history buffers (grown by doubling).
        fig, axes: optional
            User provided figure and axes.
        shape:
            Unused, kept for backward compatibility.
        kwds:
            Forwarded to PlottingOperator.
        """
        input_params = set()
        if (fig is not None) and (axes is not None):
            # Custom axes mode: parameters maps user Axes to scalar parameters.
            import matplotlib.axes

            custom_axes = True
            axes_shape = None
            check_instance(parameters, dict, keys=matplotlib.axes.Axes, values=dict)
            for params in parameters.values():
                check_instance(params, dict, keys=str, values=ScalarParameter)
                input_params.update(params.values())
        else:
            custom_axes = False
            parameters, axes_shape = self._normalize_parameters(
                parameters, input_params
            )

        super().__init__(
            name=name,
            input_params=input_params,
            axes_shape=axes_shape,
            axes=axes,
            fig=fig,
            **kwds,
        )
        self.custom_axes = custom_axes

        # One history buffer and one line per plotted parameter, keyed by the
        # parameter object itself.
        data = {}
        lines = {}
        times = npw.empty(shape=(alloc_size,), dtype=npw.float32)
        for pos, params in parameters.items():
            params_data = {}
            params_lines = {}
            for pname, p in params.items():
                params_data[p] = npw.empty(shape=(alloc_size,), dtype=p.dtype)
                params_lines[p] = self.get_axes(pos).plot([], [], label=pname)[0]
            data[pos] = params_data
            lines[pos] = params_lines

        # FigureCanvasBase.set_window_title was removed in matplotlib 3.6; the
        # title is set through the figure manager, which may be None for
        # figures that are not managed by pyplot.
        manager = getattr(self.fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title("HySoP Parameter Plotter")

        self.parameters = parameters
        self.times = times
        self.data = data
        self.lines = lines
        self.alloc_size = alloc_size
        self.counter = 0

    @staticmethod
    def _normalize_parameters(parameters, input_params):
        """
        Normalize the non-custom-axes ``parameters`` argument.

        Returns ``(normalized, axes_shape)``: ``normalized`` maps 2D axes
        positions to ``{label: scalar parameter or scalar view}`` dicts and
        ``axes_shape`` is the smallest 2D grid containing every position.
        Every referenced parameter is added to the ``input_params`` set.
        """
        if isinstance(parameters, TensorParameter):
            raw = {0: parameters}
        elif isinstance(parameters, (list, tuple)):
            # One subplot per parameter.
            raw = dict(enumerate(parameters))
        elif isinstance(parameters, dict):
            raw = parameters.copy()
        else:
            raise TypeError(type(parameters))
        check_instance(
            raw,
            dict,
            keys=(int, tuple, list),
            values=(TensorParameter, list, tuple, dict),
        )

        normalized = {}
        axes_shape = (1,) * 2
        for pos, params in raw.items():
            # Left-pad positions to 2D: an int i becomes (0, i).
            pos = to_tuple(pos)
            pos = (2 - len(pos)) * (0,) + pos
            check_instance(pos, tuple, values=int)
            axes_shape = tuple(max(p0, p1 + 1) for (p0, p1) in zip(axes_shape, pos))

            if isinstance(params, dict):
                input_params.update(params.values())
            elif isinstance(params, TensorParameter):
                input_params.add(params)
                params = {params.name: params}
            elif isinstance(params, (list, tuple)):
                input_params.update(params)
                params = {p.name: p for p in params}
            else:
                raise TypeError(type(params))
            check_instance(params, dict, keys=str, values=TensorParameter)

            # Expand non-scalar tensors into one scalar view per component.
            scalars = {}
            for pname, p in params.items():
                if isinstance(p, ScalarParameter):
                    scalars[pname] = p
                else:
                    for idx in npw.ndindex(*p.shape):
                        scalars[pname + f"_{idx}"] = p.view(idx)
            normalized[pos] = scalars
        return normalized, axes_shape

    def get_axes(self, pos):
        """Return the Axes for a position key (the key itself in custom mode)."""
        if self.custom_axes:
            return pos
        return self.axes[pos]

    def __getitem__(self, i):
        if self.custom_axes:
            return self.axes[i]
        return self.axes.flatten()[i]

    def update(self, simulation, **kwds):
        """
        Record the current time and parameter values and refresh the lines.

        ``simulation.t()`` provides the time value stored on the x axis.
        """
        # Grow the history buffers (by doubling) when they are full.
        if self.counter >= self.times.size:
            new_size = max(2 * self.times.size, 1)
            times = npw.empty(shape=(new_size,), dtype=self.times.dtype)
            times[: self.times.size] = self.times
            self.times = times
            for params in self.data.values():
                for p, pdata in params.items():
                    new_pdata = npw.empty(shape=(new_size,), dtype=pdata.dtype)
                    new_pdata[: pdata.size] = pdata
                    params[p] = new_pdata

        times, data, lines = self.times, self.data, self.lines
        i = self.counter
        n = i + 1  # number of valid samples, including the one recorded now
        times[i] = simulation.t()
        for pos, params in self.parameters.items():
            for p in params.values():
                data[pos][p][i] = p()
                lines[pos][p].set_xdata(times[:n])
                lines[pos][p].set_ydata(data[pos][p][:n])
        self.counter = n